{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Recherche sémantique avec FAISS (PyTorch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Installez les bibliothèques 🤗 Transformers, 🤗 Datasets et FAISS pour exécuter ce *notebook*."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use the %pip magic so the packages are installed into the environment of the running kernel.\n",
    "%pip install datasets evaluate transformers[sentencepiece]\n",
    "%pip install faiss-gpu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import hf_hub_url\n",
    "\n",
    "# Build the download URL of the .jsonl file stored in the dataset repository.\n",
    "data_files = hf_hub_url(\n",
    "    repo_type=\"dataset\",\n",
    "    repo_id=\"lewtun/github-issues\",\n",
    "    filename=\"datasets-issues-with-comments.jsonl\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Read the file(s) referenced by `data_files` with the \"json\" loader and keep the \"train\" split.\n",
    "issues_dataset = load_dataset(path=\"json\", split=\"train\", data_files=data_files)\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _is_commented_issue(row):\n",
    "    \"\"\"Tell whether a row is not a pull request and has at least one comment.\"\"\"\n",
    "    return row[\"is_pull_request\"] == False and len(row[\"comments\"]) > 0\n",
    "\n",
    "\n",
    "issues_dataset = issues_dataset.filter(_is_commented_issue)\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "columns = issues_dataset.column_names\n",
    "columns_to_keep = [\"title\", \"body\", \"html_url\", \"comments\"]\n",
    "# Remove every column that exists in the dataset but is not explicitly kept.\n",
    "# A plain difference (rather than a symmetric one) guarantees that only\n",
    "# existing columns are handed to `remove_columns`, even if one of the kept\n",
    "# names is absent from the dataset.\n",
    "columns_to_remove = [name for name in columns if name not in columns_to_keep]\n",
    "issues_dataset = issues_dataset.remove_columns(columns_to_remove)\n",
    "issues_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Switch the dataset's output format to pandas, then take a full slice of it.\n",
    "# Note that the format change is applied to `issues_dataset` itself.\n",
    "issues_dataset.set_format(type=\"pandas\")\n",
    "df = issues_dataset[:]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the comments stored in the row labelled 0 as a plain Python list.\n",
    "df.loc[0, \"comments\"].tolist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Give every element of the \"comments\" lists its own row and renumber the index.\n",
    "comments_df = df.explode(column=\"comments\", ignore_index=True)\n",
    "comments_df.head(n=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import Dataset\n",
    "\n",
    "# Convert the exploded DataFrame back into a Dataset.\n",
    "comments_dataset = Dataset.from_pandas(df=comments_df)\n",
    "comments_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _count_words(row):\n",
    "    \"\"\"Return the number of whitespace-separated tokens of the row's comment.\"\"\"\n",
    "    words = row[\"comments\"].split()\n",
    "    return {\"comment_length\": len(words)}\n",
    "\n",
    "\n",
    "comments_dataset = comments_dataset.map(_count_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keep only the rows whose comment_length is strictly greater than 15.\n",
    "comments_dataset = comments_dataset.filter(\n",
    "    function=lambda row: row[\"comment_length\"] > 15\n",
    ")\n",
    "comments_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def concatenate_text(examples):\n",
    "    \"\"\"Join the title, body and comment of one example into a single text field.\n",
    "\n",
    "    A field whose value is None (e.g. a missing issue body) is treated as an\n",
    "    empty string, so the concatenation does not raise a TypeError.\n",
    "\n",
    "    Args:\n",
    "        examples: mapping with the keys \"title\", \"body\" and \"comments\".\n",
    "\n",
    "    Returns:\n",
    "        A dict with the single key \"text\": the three parts joined by a\n",
    "        newline surrounded by one space on each side.\n",
    "    \"\"\"\n",
    "    parts = (examples[\"title\"], examples[\"body\"], examples[\"comments\"])\n",
    "    return {\"text\": \" \\n \".join(part or \"\" for part in parts)}\n",
    "\n",
    "\n",
    "comments_dataset = comments_dataset.map(concatenate_text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "# The tokenizer and the model are both loaded from the same checkpoint name.\n",
    "model_ckpt = \"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
    "model = AutoModel.from_pretrained(model_ckpt)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on a machine without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cls_pooling(model_output):\n",
    "    \"\"\"Return, for every row, the hidden state at index 0 of the second axis.\n",
    "\n",
    "    Presumably this is the [CLS] token representation, as the name suggests.\n",
    "    \"\"\"\n",
    "    hidden_states = model_output.last_hidden_state\n",
    "    return hidden_states[:, 0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_embeddings(text_list):\n",
    "    \"\"\"Embed one text or a list of texts with the global tokenizer and model.\n",
    "\n",
    "    Args:\n",
    "        text_list: the text(s) handed to the tokenizer.\n",
    "\n",
    "    Returns:\n",
    "        The tensor produced by cls_pooling for the encoded batch.\n",
    "    \"\"\"\n",
    "    encoded_input = tokenizer(\n",
    "        text_list, padding=True, truncation=True, return_tensors=\"pt\"\n",
    "    )\n",
    "    # Move every input tensor to the device the model lives on.\n",
    "    encoded_input = {k: v.to(device) for k, v in encoded_input.items()}\n",
    "    # Inference only: disable autograd so no graph or activations are kept in memory.\n",
    "    with torch.no_grad():\n",
    "        model_output = model(**encoded_input)\n",
    "    return cls_pooling(model_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: embed the text of the first row and look at the shape of the result.\n",
    "embedding = get_embeddings(comments_dataset[0][\"text\"])\n",
    "embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _embed_row(row):\n",
    "    \"\"\"Embed the row's text and return the first row of the result as a NumPy array.\"\"\"\n",
    "    vectors = get_embeddings(row[\"text\"]).detach().cpu().numpy()\n",
    "    return {\"embeddings\": vectors[0]}\n",
    "\n",
    "\n",
    "embeddings_dataset = comments_dataset.map(_embed_row)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a FAISS index over the \"embeddings\" column, using the default index settings.\n",
    "embeddings_dataset.add_faiss_index(\"embeddings\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "question = \"How can I load a dataset offline?\"\n",
    "# Embed the question as a batch of one and convert the result to a NumPy array.\n",
    "question_embedding = get_embeddings([question]).detach().cpu().numpy()\n",
    "question_embedding.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query the \"embeddings\" index for the 5 nearest rows and their scores.\n",
    "scores, samples = embeddings_dataset.get_nearest_examples(\n",
    "    index_name=\"embeddings\", query=question_embedding, k=5\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "samples_df = pd.DataFrame.from_dict(samples)\n",
    "samples_df[\"scores\"] = scores\n",
    "# The index was built with the default (L2) metric, so each score is a distance:\n",
    "# the smallest value is the best match and has to come first.\n",
    "samples_df = samples_df.sort_values(\"scores\", ascending=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print each retrieved row as a small report followed by a separator and a blank line.\n",
    "for _, hit in samples_df.iterrows():\n",
    "    print(\n",
    "        f\"COMMENT: {hit.comments}\",\n",
    "        f\"SCORE: {hit.scores}\",\n",
    "        f\"TITLE: {hit.title}\",\n",
    "        f\"URL: {hit.html_url}\",\n",
    "        \"=\" * 50,\n",
    "        \"\",\n",
    "        sep=\"\\n\",\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "Recherche sémantique avec FAISS (PyTorch)",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
